package com.tuniu.agents.advisor;

import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import java.util.*;

/**
 * {@link ToolCallingManager} for tool calls that were produced by a streaming chat response.
 *
 * <p>When a model streams a tool call, the call can reach this manager as several
 * {@link AssistantMessage.ToolCall} fragments whose {@code arguments} are consecutive pieces of one
 * JSON document. This manager merges those fragments into a single tool call (id, type and name are
 * taken from the first fragment; arguments are concatenated in order), executes it, and returns the
 * resulting conversation history, so the call can be handled as if it came from a non-streaming
 * response.
 *
 * <p>NOTE(review): every tool-call entry of the selected generation is merged into one call, so a
 * response carrying several distinct tool calls would be executed as a single call with concatenated
 * arguments — confirm callers only route single-call streams through this manager.
 */
public class StreamingToolCallingManager implements ToolCallingManager {

    private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY = ObservationRegistry.NOOP;

    private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER =
            new DelegatingToolCallbackResolver(List.of());

    private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR =
            DefaultToolExecutionExceptionProcessor.builder().build();

    /** Observation registry supplied at construction; not read by the code in this class. */
    private final ObservationRegistry observationRegistry;

    /** Fallback lookup for tool names that are not among the prompt's own callbacks. */
    private final ToolCallbackResolver toolCallbackResolver;

    /** Turns a {@link ToolExecutionException} into the text sent back as the tool result. */
    private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

    /**
     * Creates a manager with a no-op observation registry, a resolver that knows no tools, and the
     * default {@link DefaultToolExecutionExceptionProcessor}.
     */
    public StreamingToolCallingManager() {
        this(DEFAULT_OBSERVATION_REGISTRY, DEFAULT_TOOL_CALLBACK_RESOLVER,
                DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR);
    }

    /**
     * Creates a manager with explicit collaborators.
     *
     * @param observationRegistry observation registry; must not be {@code null}
     * @param toolCallbackResolver resolver used for tool names without an explicit callback; must not
     *     be {@code null}
     * @param toolExecutionExceptionProcessor converts tool execution failures into tool results; must
     *     not be {@code null}
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public StreamingToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver, ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
        Assert.notNull(toolExecutionExceptionProcessor, "toolExecutionExceptionProcessor cannot be null");

        this.observationRegistry = observationRegistry;
        this.toolCallbackResolver = toolCallbackResolver;
        this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
    }

    /**
     * Resolves the definitions of every tool made available by the given options.
     *
     * <p>Explicit callbacks come first; each additional tool name is resolved through the
     * {@link ToolCallbackResolver} unless a callback with that name was already supplied.
     *
     * @param chatOptions options carrying tool callbacks and tool names; must not be {@code null}
     * @return one definition per distinct tool
     * @throws IllegalStateException if a tool name cannot be resolved to a callback
     */
    @Override
    public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
        Assert.notNull(chatOptions, "chatOptions cannot be null");

        List<FunctionCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());

        // A tool may be declared both as a callback and by name; resolve it only once.
        Set<String> explicitToolNames = new HashSet<>();
        for (FunctionCallback callback : chatOptions.getToolCallbacks()) {
            explicitToolNames.add(callback.getName());
        }

        for (String toolName : chatOptions.getToolNames()) {
            if (explicitToolNames.contains(toolName)) {
                continue;
            }
            FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }
            toolCallbacks.add(toolCallback);
        }

        return toolCallbacks.stream().map(functionCallback -> {
            if (functionCallback instanceof ToolCallback toolCallback) {
                return toolCallback.getToolDefinition();
            }
            // Legacy FunctionCallback: synthesize a definition from its name/description/schema.
            return ToolDefinition.builder()
                    .name(functionCallback.getName())
                    .description(functionCallback.getDescription())
                    .inputSchema(functionCallback.getInputTypeSchema())
                    .build();
        }).toList();
    }

    /**
     * Merges the streamed tool-call fragments of the first generation that requested tools, executes
     * the merged call, and returns the updated conversation history.
     *
     * @param prompt the prompt that produced {@code chatResponse}; must not be {@code null}
     * @param chatResponse the (aggregated) streaming response; must not be {@code null}
     * @return conversation history including the merged assistant message and the tool response
     * @throws NoSuchElementException if no generation in {@code chatResponse} contains tool calls
     * @throws IllegalStateException if the requested tool cannot be found
     */
    @Override
    public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
        Assert.notNull(prompt, "prompt cannot be null");
        Assert.notNull(chatResponse, "chatResponse cannot be null");

        Generation generation = chatResponse.getResults()
                .stream()
                .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
                .findFirst()
                .orElseThrow(() -> new NoSuchElementException("No generation with tool calls found in chat response"));

        AssistantMessage assistantMessage = mergeStreamedToolCalls(generation.getOutput());

        ToolContext toolContext = buildToolContext(prompt, assistantMessage);

        InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage,
                toolContext);

        List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
                assistantMessage, internalToolExecutionResult.toolResponseMessage());

        return ToolExecutionResult.builder()
                .conversationHistory(conversationHistory)
                .returnDirect(internalToolExecutionResult.returnDirect())
                .build();
    }

    /**
     * Collapses the tool-call fragments of a streamed assistant message into a single tool call.
     * The id, type and name come from the first fragment; the arguments of all fragments are
     * concatenated in order. Fragments whose arguments are {@code null} contribute nothing, so no
     * literal {@code "null"} text ends up in the JSON.
     */
    private static AssistantMessage mergeStreamedToolCalls(AssistantMessage output) {
        List<AssistantMessage.ToolCall> fragments = output.getToolCalls();

        String arguments = String.join("", fragments.stream()
                .map(AssistantMessage.ToolCall::arguments)
                .filter(Objects::nonNull)
                .toList());

        AssistantMessage.ToolCall head = fragments.get(0);
        AssistantMessage.ToolCall merged =
                new AssistantMessage.ToolCall(head.id(), head.type(), head.name(), arguments);

        return new AssistantMessage(output.getText(), output.getMetadata(), List.of(merged), output.getMedia());
    }

    /**
     * Builds the {@link ToolContext} passed to the tool. When the prompt options carry a non-empty
     * tool context, it is copied and the conversation history so far (prompt instructions plus the
     * assistant message) is added under {@link ToolContext#TOOL_CALL_HISTORY}; otherwise the context
     * is empty.
     */
    private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) {
        Map<String, Object> toolContextMap = Map.of();

        if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions
                && !CollectionUtils.isEmpty(functionOptions.getToolContext())) {
            toolContextMap = new HashMap<>(functionOptions.getToolContext());
            toolContextMap.put(ToolContext.TOOL_CALL_HISTORY,
                    buildConversationHistoryBeforeToolExecution(prompt, assistantMessage));
        }

        return new ToolContext(toolContextMap);
    }

    /** Returns a copy of the prompt instructions followed by a media-less copy of the assistant message. */
    private static List<Message> buildConversationHistoryBeforeToolExecution(Prompt prompt,
                                                                             AssistantMessage assistantMessage) {
        List<Message> messageHistory = new ArrayList<>(prompt.copy().getInstructions());
        messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(),
                assistantMessage.getToolCalls()));
        return messageHistory;
    }

    /**
     * Executes every tool call on the assistant message and collects the responses.
     *
     * <p>Callbacks are looked up first in the prompt options, then via the resolver. The returned
     * {@code returnDirect} flag is the AND of the flags of all {@link ToolCallback}s invoked; a
     * legacy {@link FunctionCallback} seen first sets it to {@code false}.
     *
     * @throws IllegalStateException if a requested tool cannot be found
     */
    private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
                                                        ToolContext toolContext) {
        List<FunctionCallback> toolCallbacks = List.of();
        if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
            toolCallbacks = toolCallingChatOptions.getToolCallbacks();
        }
        else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) {
            toolCallbacks = functionOptions.getFunctionCallbacks();
        }

        List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();

        Boolean returnDirect = null;

        for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

            String toolName = toolCall.name();
            String toolInputArguments = toolCall.arguments();

            FunctionCallback toolCallback = toolCallbacks.stream()
                    .filter(tool -> toolName.equals(tool.getName()))
                    .findFirst()
                    .orElseGet(() -> toolCallbackResolver.resolve(toolName));

            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }

            if (returnDirect == null && toolCallback instanceof ToolCallback callback) {
                returnDirect = callback.getToolMetadata().returnDirect();
            }
            else if (toolCallback instanceof ToolCallback callback) {
                returnDirect = returnDirect && callback.getToolMetadata().returnDirect();
            }
            else if (returnDirect == null) {
                // Temporary backward compatibility with FunctionCallback, which has no returnDirect
                // metadata. TODO: remove this block when FunctionCallback is removed.
                returnDirect = false;
            }

            String toolResult;
            try {
                toolResult = toolCallback.call(toolInputArguments, toolContext);
            }
            catch (ToolExecutionException ex) {
                // Report the failure to the model as the tool's result instead of aborting.
                toolResult = toolExecutionExceptionProcessor.process(ex);
            }

            toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult));
        }

        // Boolean.TRUE.equals guards against unboxing null when no tool call was processed.
        return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()),
                Boolean.TRUE.equals(returnDirect));
    }

    /** Returns a new list: the previous messages, then the assistant message, then the tool responses. */
    private static List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
                                                                            AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) {
        List<Message> messages = new ArrayList<>(previousMessages);
        messages.add(assistantMessage);
        messages.add(toolResponseMessage);
        return messages;
    }

    /** Outcome of executing the tool calls: the tool response message and the returnDirect flag. */
    private record InternalToolExecutionResult(ToolResponseMessage toolResponseMessage, boolean returnDirect) {
    }
}
